#
# Copyright (C) 2002-2006 greg Landrum and Rational Discovery LLC
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
""" utility functionality for the 2D pharmacophores code

  See Docs/Chem/Pharm2D.triangles.jpg for an illustration of the way
  pharmacophores are broken into triangles and labelled.

  See Docs/Chem/Pharm2D.signatures.jpg for an illustration of bit
  numbering

"""
from __future__ import print_function, division

import itertools

#
#  Maps the number of points in a scaffold to the sequence of point
#  pairs (p1, p2) whose distances describe that scaffold.
#
#  For every point count from 2 through 10 the pairs are listed as:
#    - first the pairs joining point 0 to each other point, in order;
#    - then the pairs joining consecutive points (1,2), (2,3), ...
#
nPointDistDict = {
  nPts: (tuple((0, other) for other in range(1, nPts)) +
         tuple((pt, pt + 1) for pt in range(1, nPts - 1)))
  for nPts in range(2, 11)
}

#
#  Maps the number of distances in a scaffold back to the number of
#  points in that scaffold. A scaffold with nPts points (2 through 10)
#  is described by 2*nPts - 3 distances.
#
nDistPointDict = {2 * nPts - 3: nPts for nPts in range(2, 11)}

# cache used by GetTriangles: number of points -> tuple of triangles
_trianglesInPharmacophore = {}


def GetTriangles(nPts):
  """ returns the distance indices for the triangles composing
   an nPts-pharmacophore

  **Arguments**

    - nPts: the number of points in the pharmacophore

  **Returns**

    a tuple of 3-tuples, each holding the indices of the three
    distances making up one triangle. For fewer than 3 points there
    are no triangles and an empty list is returned.

  **Notes**

    - results are cached per value of nPts and the cached tuple itself
      is handed back on later calls

  """
  global _trianglesInPharmacophore
  if nPts < 3:
    return []
  known = _trianglesInPharmacophore.get(nPts, [])
  if known:
    return known
  # triangle i is made of distance i, distance i+1 and distance i+nPts-1
  triangles = tuple((first, first + 1, first + nPts - 1) for first in range(nPts - 2))
  _trianglesInPharmacophore[nPts] = triangles
  return triangles


def _fact(x):
  """ returns x! for an integer x

  Any value of 1 or less (including negative values) gives 1.

  """
  if x <= 1:
    return 1

  product = 1
  # multiplying by 1 changes nothing, so start the run at 2
  for factor in range(2, x + 1):
    product *= factor
  return product


def BinsTriangleInequality(d1, d2, d3):
  """ checks the triangle inequality for combinations
    of distance bins.

    the general triangle inequality is:
       d1 + d2 >= d3
    the conservative binned form of this is:
       d1(upper) + d2(upper) >= d3(lower)

  **Arguments**

    - d1, d2, d3: distance bins; element 0 of each is used as the
      lower bound and element 1 as the upper bound

  **Returns**

    False if, for any of the three sides, the upper bounds of the other
    two sides sum to less than its lower bound; True otherwise

  """
  # each rotation tests one side against the sum of the other two
  rotations = ((d1, d2, d3), (d2, d3, d1), (d3, d1, d2))
  return not any(first[1] + second[1] < third[0] for first, second, third in rotations)


def ScaffoldPasses(combo, bins=None):
  """ checks the scaffold passed in to see if all
  contributing triangles can satisfy the triangle inequality

  **Arguments**

    - combo: the scaffold, encoded as a sequence of bin indices (one
      per distance in the scaffold)

    - bins: the sequence of distance bins that the entries of combo
      index into. Although it defaults to None, it is indexed as soon
      as the scaffold contains a triangle, so it must be provided for
      scaffolds with 3 or more points.

  **Returns**

    True if every triangle of the scaffold passes
    BinsTriangleInequality, False otherwise

  """
  # the number of distances determines the number of points
  nPts = nDistPointDict[len(combo)]
  return all(
    BinsTriangleInequality(bins[combo[a]], bins[combo[b]], bins[combo[c]])
    for a, b, c in GetTriangles(nPts))


# cache used by NumCombinations: (nItems, nSlots) -> number of combinations
_numCombDict = {}


def NumCombinations(nItems, nSlots):
  """  returns the number of ways to fit nItems into nSlots

    We assume that (x,y) and (y,x) are equivalent, and
    (x,x) is allowed.

    General formula is, for N items and S slots:
      res = (N+S-1)! / ( (N-1)! * S! )

    Results are remembered in a module-level cache keyed on
    (nItems, nSlots).

  """
  global _numCombDict
  key = (nItems, nSlots)
  try:
    return _numCombDict[key]
  except KeyError:
    pass
  numerator = _fact(nItems + nSlots - 1)
  denominator = _fact(nItems - 1) * _fact(nSlots)
  count = numerator // denominator
  _numCombDict[key] = count
  return count


# set to a true value to make CountUpTo print a trace of its recursion
_verbose = 0

# cache used by CountUpTo: (nItems, nSlots, tuple(vs)) -> position
_countCache = {}


def CountUpTo(nItems, nSlots, vs, idx=0, startAt=0):
  """ Figures out where a given combination of indices would
   occur in the combinatorial explosion generated by _GetIndexCombinations_

   **Arguments**

     - nItems: the number of items to distribute

     - nSlots: the number of slots in which to distribute them

     - vs: a sequence containing the values to find

     - idx: used in the recursion

     - startAt: used in the recursion

  **Returns**

     an integer

  **Notes**

     - only top-level calls (idx == 0) read from and write to the
       module-level cache

  """
  global _countCache
  pad = '  ' * idx
  if _verbose:
    print(pad, 'CountUpTo(%d)' % idx, vs[idx], startAt)

  topLevel = (idx == 0)
  if topLevel:
    cacheKey = (nItems, nSlots, tuple(vs))
    if cacheKey in _countCache:
      return _countCache[cacheKey]

  if idx >= nSlots:
    total = 0
  elif idx == nSlots - 1:
    # last slot: just count the values skipped in this slot
    total = vs[idx] - startAt
  else:
    total = 0
    slotsBelow = nSlots - idx - 1
    # every smaller value in this slot contributes all of the
    # combinations that can be formed in the slots below it
    for val in range(startAt, vs[idx]):
      valsLeft = nItems - val
      nBelow = NumCombinations(valsLeft, slotsBelow)
      if _verbose:
        print(pad, ' ', val, valsLeft, slotsBelow, nBelow)
      total += nBelow
    # then move on to the next slot
    total += CountUpTo(nItems, nSlots, vs, idx + 1, vs[idx])

  if _verbose:
    print(pad, '>', total)
  if topLevel:
    _countCache[cacheKey] = total
  return total


# cache used by GetIndexCombinations: (nItems, nSlots) -> list of combinations
_indexCombinations = {}


def GetIndexCombinations(nItems, nSlots, slot=0, lastItemVal=0):
  """ Generates all combinations of nItems in nSlots without including
    duplicates

  **Arguments**

    - nItems: the number of items to distribute

    - nSlots: the number of slots in which to distribute them

    - slot: index of the first slot to fill (used in recursion by
      earlier versions; kept for backwards compatibility)

    - lastItemVal: smallest item value allowed (used in recursion by
      earlier versions; kept for backwards compatibility)

  **Returns**

    a list of lists. Each inner list is non-decreasing and the outer
    list is in lexicographic order. An empty list is returned if there
    are no slots left to fill.

  **Notes**

    - only the canonical call (slot == 0 and lastItemVal == 0) is
      cached. The cache key does not include lastItemVal, so caching any
      other call would make later canonical calls return a truncated
      result.

    - the cached list itself is returned, callers should not modify it

  >>> GetIndexCombinations(3, 2)
  [[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]]
  >>> GetIndexCombinations(3, 2, 0, 1)
  [[1, 1], [1, 2], [2, 2]]
  >>> GetIndexCombinations(3, 2)
  [[0, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 2]]

  """
  global _indexCombinations
  nToFill = nSlots - slot
  if nToFill <= 0:
    return []

  cacheable = (slot == 0 and lastItemVal == 0)
  key = (nItems, nSlots)
  if cacheable and key in _indexCombinations:
    return _indexCombinations[key]

  # combinations_with_replacement yields exactly the non-decreasing
  # selections, in the same lexicographic order as the explicit recursion
  items = range(lastItemVal, nItems)
  res = [list(combo) for combo in itertools.combinations_with_replacement(items, nToFill)]
  if cacheable:
    _indexCombinations[key] = res
  return res


def GetAllCombinations(choices, noDups=1, which=0):
  """  Does the combinatorial explosion of the possible combinations
  of the elements of _choices_.

  **Arguments**

    - choices: sequence of sequences with the elements to be enumerated

    - noDups: (optional) if this is nonzero, results with duplicates,
      e.g. (1,1,0), will not be generated

    - which: index of the first entry of _choices_ to use
      (used in recursion by earlier versions)

  **Returns**

    a list of lists

  >>> GetAllCombinations([(0,),(1,),(2,)])
  [[0, 1, 2]]
  >>> GetAllCombinations([(0,),(1,3),(2,)])
  [[0, 1, 2], [0, 3, 2]]

  >>> GetAllCombinations([(0,1),(1,3),(2,)])
  [[0, 1, 2], [0, 3, 2], [1, 3, 2]]

  """
  nChoices = len(choices)
  if which >= nChoices:
    return []

  dupsAllowed = not noDups
  # start with the last position and grow the combinations leftwards
  tails = [[item] for item in choices[nChoices - 1]]
  for level in range(nChoices - 2, which - 1, -1):
    tails = [[head] + tail for head in choices[level] for tail in tails
             if dupsAllowed or head not in tail]
  return tails


def GetUniqueCombinations(choices, classes, which=0):
  """  Does the combinatorial explosion of the possible combinations
  of the elements of _choices_.

  **Arguments**

    - choices: sequence of sequences with the elements to be enumerated

    - classes: sequence, the same length as _choices_, holding the label
      attached to the elements picked from the corresponding entry

    - which: index of the first entry of _choices_ to use
      (used in recursion by earlier versions)

  **Returns**

    a list of lists of (class, element) tuples. Each inner list is
    sorted, never uses the same element twice, and appears only once in
    the result.

  """
  assert len(choices) == len(classes)
  nChoices = len(choices)
  if which >= nChoices:
    return []

  last = nChoices - 1
  partial = [[(classes[last], item)] for item in choices[last]]
  # grow the combinations leftwards, one position at a time
  for level in range(last - 1, which - 1, -1):
    label = classes[level]
    grown = []
    for item in choices[level]:
      for tail in partial:
        # an element may only be used once per combination
        if any(entry[1] == item for entry in tail):
          continue
        candidate = sorted([(label, item)] + tail)
        if candidate not in grown:
          grown.append(candidate)
    partial = grown
  return partial


def GetUniqueCombinations_new(choices, classes, which=0):
  """  Does the combinatorial explosion of the possible combinations
  of the elements of _choices_.

  **Arguments**

    - choices: sequence of sequences with the elements to be enumerated

    - classes: sequence, the same length as _choices_, holding the label
      attached to the elements picked from the corresponding entry

    - which: not used by this implementation

  **Returns**

    a sorted list of lists of (class, element) tuples. Each inner list
    is sorted, never uses the same element twice, and appears only once
    in the result.

  """
  assert len(choices) == len(classes)
  # picks that use the same element in more than one position are dropped;
  # sorting each survivor makes equivalent picks collapse in the set
  unique = {
    tuple(sorted(zip(classes, picks)))
    for picks in itertools.product(*choices) if len(set(picks)) == len(picks)
  }
  return list(map(list, sorted(unique)))


def UniquifyCombinations(combos):
  """ uniquifies the combinations in the argument

    Two combinations are considered equivalent if they contain the same
    elements, regardless of order. When several equivalent combinations
    are present, the last one encountered is the one that is kept.

    **Arguments**:

      - combos: a sequence of sequences (lists, tuples, ...)

    **Returns**

      - a list of tuples containing the unique combos

  >>> UniquifyCombinations([(1, 2), (2, 1)])
  [(2, 1)]

  """
  resD = {}
  for combo in combos:
    # sorted() accepts any iterable, so tuples work as well as lists
    resD[tuple(sorted(combo))] = tuple(combo)
  return list(resD.values())


def GetPossibleScaffolds(nPts, bins, useTriangleInequality=True):
  """ gets all realizable scaffolds (passing the triangle inequality) with the
     given number of points and returns them as a list of tuples

  **Arguments**

    - nPts: the number of points in the scaffold

    - bins: the sequence of distance bins

    - useTriangleInequality: (optional) if false, scaffolds are not
      filtered with ScaffoldPasses and every combination of bins is
      returned

  **Returns**

    a list of tuples of bin indices, one index per distance in the
    scaffold. NOTE: for nPts < 2 the integer 0 is returned instead of
    a list.

  """
  if nPts < 2:
    return 0
  if nPts == 2:
    # a single distance: every bin is a valid scaffold
    return [(binIdx, ) for binIdx in range(len(bins))]

  nDists = len(nPointDistDict[nPts])
  candidates = GetAllCombinations([range(len(bins))] * nDists, noDups=0)
  return [
    tuple(candidate) for candidate in candidates
    if not useTriangleInequality or ScaffoldPasses(candidate, bins)
  ]


def OrderTriangle(featIndices, dists):
  """
    put the distances for a triangle into canonical order

    **Arguments**

      - featIndices: sequence with the 3 feature indices of the triangle

      - dists: sequence with the 3 distances of the triangle

    **Returns**

      a 2-tuple (featIndices, dists). If the three features are all
      different the arguments are returned unchanged, otherwise new,
      reordered lists are returned.

    **Raises**

      ValueError if either argument does not have exactly 3 elements

    It's easy if the features are all different:
    >>> OrderTriangle([0,2,4],[1,2,3])
    ([0, 2, 4], [1, 2, 3])

    It's trickiest if they are all the same:
    >>> OrderTriangle([0,0,0],[1,2,3])
    ([0, 0, 0], [3, 2, 1])
    >>> OrderTriangle([0,0,0],[2,1,3])
    ([0, 0, 0], [3, 2, 1])
    >>> OrderTriangle([0,0,0],[1,3,2])
    ([0, 0, 0], [3, 2, 1])
    >>> OrderTriangle([0,0,0],[3,1,2])
    ([0, 0, 0], [3, 2, 1])
    >>> OrderTriangle([0,0,0],[3,2,1])
    ([0, 0, 0], [3, 2, 1])

    >>> OrderTriangle([0,0,1],[3,2,1])
    ([0, 0, 1], [3, 2, 1])
    >>> OrderTriangle([0,0,1],[1,3,2])
    ([0, 0, 1], [1, 3, 2])
    >>> OrderTriangle([0,0,1],[1,2,3])
    ([0, 0, 1], [1, 3, 2])
    >>> OrderTriangle([0,0,1],[1,3,2])
    ([0, 0, 1], [1, 3, 2])

  """
  if len(featIndices) != 3:
    raise ValueError('bad indices')
  if len(dists) != 3:
    raise ValueError('bad dists')

  distinctFeats = set(featIndices)
  if len(distinctFeats) == 3:
    return featIndices, dists

  pairSums = [dists[0] + dists[1], dists[0] + dists[2], dists[1] + dists[2]]
  largest = max(pairSums)

  # Each rule is (lhs, rhs, ordersIfGreater, ordersOtherwise): the two
  # distances dists[lhs] and dists[rhs] are compared and the outcome
  # selects a (feature order, distance order) pair.
  unchanged = ((0, 1, 2), (0, 1, 2))
  if len(distinctFeats) == 1:
    # one class: the rule depends on which pair of distances has the largest sum
    if pairSums[0] == largest:
      rule = (0, 1, unchanged, ((0, 2, 1), (1, 0, 2)))
    elif pairSums[1] == largest:
      rule = (0, 2, ((1, 0, 2), (0, 2, 1)), ((1, 2, 0), (2, 0, 1)))
    else:
      rule = (1, 2, ((2, 0, 1), (1, 2, 0)), ((2, 1, 0), (2, 1, 0)))
  elif featIndices[0] == featIndices[1]:
    # two classes, features 0 and 1 match
    rule = (1, 2, unchanged, ((1, 0, 2), (0, 2, 1)))
  elif featIndices[0] == featIndices[2]:
    # two classes, features 0 and 2 match
    rule = (0, 2, unchanged, ((2, 1, 0), (2, 1, 0)))
  else:
    # two classes, features 1 and 2 match
    rule = (0, 1, unchanged, ((0, 2, 1), (1, 0, 2)))

  lhs, rhs, ordersIfGreater, ordersOtherwise = rule
  if dists[lhs] > dists[rhs]:
    featOrder, distOrder = ordersIfGreater
  else:
    featOrder, distOrder = ordersOtherwise
  newDists = [dists[pos] for pos in distOrder]
  newFeats = [featIndices[pos] for pos in featOrder]
  return newFeats, newDists


# ------------------------------------
#
#  doctest boilerplate
#
def _runDoctests(verbose=None):  # pragma: nocover
  """ runs doctest.testmod with ELLIPSIS matching enabled and exits the
  interpreter, using the number of failed tests as the exit status

  **Arguments**

    - verbose: passed through to doctest.testmod

  """
  import doctest
  import sys
  outcome = doctest.testmod(optionflags=doctest.ELLIPSIS, verbose=verbose)
  # the first element of the result is the number of failures
  sys.exit(outcome[0])


if __name__ == '__main__':  # pragma: nocover
  _runDoctests()
